#!/usr/bin/env python3

# Python script to run and analyse MMS test

#requires: all_tests

from __future__ import division
from __future__ import print_function
from builtins import str

from boututils.run_wrapper import shell, shell_safe, launch_safe
from boutdata.collect import collect

import pickle
from numpy import sqrt, max, abs, mean, array, log, zeros

from os.path import join

import time


# Build the test executable; the output of make is redirected to make.log
print("Making MMS test")
shell_safe("make > make.log")

# Resolutions to run.
# A fuller scan would be [256, 128, 64, 32, 16, 8]
# (done in reverse order to save disk space).
nxlist = [32, 64]

# Passed to every run as "nout=" and "timestep="
nout = 1
timestep = 0.01

# nproc is the process count given to the MPI launcher; mxg is passed to
# every run as "mxg=" and is added twice to the resolution to form mesh:nx
nproc = 8
mxg = 2

# Input directories
inputs = [
    "orthogonal",
    "shifted",
    "shear",
    "nonorthogonal",
    "nonorth_shifted",
    "nonorth_shear",
    "polshear",
    "polshear_shifted",
    "polshear_shear",
]
# Legend label for each entry of `inputs`, in the same order
input_legend_names = [
    "orthogonal",
    "z-shifted",
    "z-shear",
    "y-shifted",
    "y-z shifted",
    "y-shifted z-shear",
    "y-shear",
    "y-shear z-shifted",
    "y-z shear",
]

# Cleared as soon as one input misses the expected convergence order
success = True

ninputs = len(inputs)
# One row per input directory, one column per resolution
tot_error = zeros((ninputs, len(nxlist)))

# Run every input directory at each resolution, check the convergence order
# of the L2 error, and store the L2 errors in tot_error for later plotting.

# Grid spacing for each resolution; it does not depend on the input, so it is
# computed once before the loop.
dx = 1. / (array(nxlist) - 2.)

for i, input_dir in enumerate(inputs):
    directory = "./" + input_dir
    print("Running test in '%s'" % (directory))

    error_2 = []  # The L2 error (RMS)
    error_inf = []  # The maximum error

    for nx in nxlist:
        args = (
            "-d " + directory
            + " nout=" + str(nout)
            + " timestep=" + str(timestep)
            + " mesh:ny=" + str(nx)
            + " mesh:nz=" + str(nx)
            + " mesh:nx=" + str(nx + 2 * mxg)
            + " mxg=" + str(mxg)
        )

        print("  Running with " + args)

        # Delete old data
        shell("rm " + directory + "/BOUT.dmp.*.nc")

        # Command to run
        cmd = "./fieldalign " + args

        # Launch using MPI
        _, out = launch_safe(cmd, nproc=nproc, pipe=True)

        # Save output to log file.
        # NOTE: the file name depends only on nx, so each input directory
        # overwrites the log written by the previous one.
        with open("run.log." + str(nx), "w") as f:
            f.write(out)

        # Collect data
        E_f = collect("E_f", tind=[nout, nout], path=directory, info=False)

        # Error norms over everything collect() returned; max and abs are the
        # NumPy functions imported at the top of the file.
        # NOTE(review): nothing is sliced off here, so whether guard cells
        # are excluded depends on collect()'s defaults - confirm.
        l2 = sqrt(mean(E_f**2))
        linf = max(abs(E_f))

        error_2.append(l2)
        error_inf.append(linf)

        print("  -> Error norm: l-2 %f l-inf %f" % (l2, linf))

    # Record this input's L2 errors; without this the row stays all zeros
    tot_error[i, :] = error_2

    # Calculate convergence order from the last two resolutions
    order = log(error_2[-1] / error_2[-2]) / log(dx[-1] / dx[-2])

    print("Convergence order = %f" % (order))
    if not (1.8 < order < 2.2):
        success = False
        print("=> FAILED\n")

    # Plotting is best-effort: a missing matplotlib or a plotting error must
    # not abort the test, but it is reported instead of silently dropped.
    try:
        import matplotlib.pyplot as plt

        plt.figure()

        plt.plot(dx, error_2, '-o', label=r'$l^2$')
        plt.plot(dx, error_inf, '-x', label=r'$l^\infty$')

        # Reference line with the measured order through the last point
        plt.plot(dx, error_2[-1]*(dx/dx[-1])**order, '--', label="Order %.1f"%(order))

        plt.legend(loc="upper left")
        plt.grid()

        plt.yscale('log')
        plt.xscale('log')

        plt.xlabel(r'Mesh spacing $\delta x$')
        plt.ylabel("Error norm")

        plt.savefig(join(input_dir, "norm.pdf"))

        plt.close()

    except Exception as exc:
        print("Could not plot errors for '%s': %s" % (directory, exc))

# Summary plot: one curve per input directory (rows of tot_error against dx),
# saved as norm.pdf in the current directory.
# Plotting is best-effort. matplotlib is imported here so that its absence
# skips the plot instead of raising NameError before the result is reported.
try:
    import matplotlib.pyplot as plt

    # Start from a fresh figure rather than whatever figure is current
    plt.figure()

    plt.xlabel(r'Mesh spacing $\delta x$')
    plt.ylabel("Error norm")
    plt.grid()
    plt.yscale('log')
    plt.xscale('log')

    # One line/marker format string per input
    marker = ["-o","-p","-x","-v","-*","-^","-D","-s","-8"]
    for fmt, legend_name, errors in zip(marker, input_legend_names, tot_error):
        plt.plot(dx, errors, fmt, label=str(legend_name))
    plt.legend(loc="upper left")
    plt.savefig(join("./","norm.pdf"))
    plt.close()
except Exception as exc:
    print("Could not produce summary plot: %s" % (exc))

# Report the overall result and pass it on as the process exit status
# (0 when every input converged as expected, 1 otherwise).
# SystemExit is raised directly: the exit() helper is installed by the site
# module for interactive use and is not guaranteed to exist in every run.
print(" => All tests passed" if success else " => Some failed tests")
raise SystemExit(0 if success else 1)
